{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# {class}`~tvm.contrib.msc.pipeline.TorchDynamic`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc\n"
     ]
    }
   ],
   "source": [
    "# Step up to the parent directory so the relative paths below resolve from there.\n",
    "%cd ..\n",
    "from pathlib import Path\n",
    "\n",
    "# Scratch directory, relative to the working directory selected by %cd above;\n",
    "# _get_config builds its workspace paths underneath it.\n",
    "temp_dir = Path(\".temp\")\n",
    "temp_dir.mkdir(exist_ok=True)  # does not fail when the directory already exists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _get_torch_model(name, training=False):\n",
    "    \"\"\"Build a torchvision model by name.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    name : str\n",
    "        Attribute name of a model constructor in ``torchvision.models``\n",
    "        (for example ``resnet50``). The constructor is called without arguments.\n",
    "    training : bool\n",
    "        If True the model is switched to train mode, otherwise to eval mode.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    model\n",
    "        The constructed model, or None when torchvision cannot be imported.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    AttributeError\n",
    "        If ``torchvision.models`` has no attribute called ``name``.\n",
    "    \"\"\"\n",
    "\n",
    "    # pylint: disable=import-outside-toplevel\n",
    "    try:\n",
    "        import torchvision\n",
    "    except ImportError:\n",
    "        # Only a missing package is reported as an install problem; any other\n",
    "        # failure (such as a wrong model name) propagates to the caller instead\n",
    "        # of being hidden behind a misleading message.\n",
    "        print(\"please install torchvision package\")\n",
    "        return None\n",
    "\n",
    "    model = getattr(torchvision.models, name)()\n",
    "    return model.train() if training else model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.contrib.msc.core import utils as msc_utils\n",
    "\n",
    "def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1e-1, rtol=1e-1):\n",
    "    \"\"\"Assemble the config dict handed to an MSC pipeline.\n",
    "\n",
    "    The workspace is obtained from ``msc_utils.msc_dir`` for a path under\n",
    "    ``temp_dir`` whose name encodes the model type, the compile type and\n",
    "    whether the run is dynamic or static.  ``atol`` and ``rtol`` are used as\n",
    "    the check tolerances of both the baseline and the compile stage.\n",
    "    \"\"\"\n",
    "\n",
    "    def _stage(run_type):\n",
    "        # Called once per stage so the stages never share nested dict objects.\n",
    "        return {\n",
    "            \"run_type\": run_type,\n",
    "            \"profile\": {\"check\": {\"atol\": atol, \"rtol\": rtol}, \"benchmark\": {\"repeat\": 10}},\n",
    "        }\n",
    "\n",
    "    mode = \"dynamic\" if dynamic else \"static\"\n",
    "    workspace = msc_utils.msc_dir(\n",
    "        f\"{temp_dir}/test_pipe_{model_type}_{compile_type}_{mode}\", keep_history=False\n",
    "    )\n",
    "    return {\n",
    "        \"workspace\": workspace,\n",
    "        \"verbose\": \"critical\",\n",
    "        \"model_type\": model_type,\n",
    "        \"inputs\": inputs,\n",
    "        \"outputs\": outputs,\n",
    "        \"dataset\": {\"prepare\": {\"loader\": \"from_random\", \"max_iter\": 5}},\n",
    "        \"prepare\": {\"profile\": {\"benchmark\": {\"repeat\": 10}}},\n",
    "        \"baseline\": _stage(model_type),\n",
    "        \"compile\": _stage(compile_type),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.\n",
      "  param_schemas = callee.param_schemas()\n",
      "/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.\n",
      "  param_schemas = callee.param_schemas()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'success': False, 'info': {'prepare': {'profile': {'jit_0': '46.75 ms @ cpu'}}}, 'duration': {'setup': '0.00 s(0.00%)', 'prepare': '6.19 s(49.31%)', 'parse': '0.09 s(0.68%)', 'total': '12.55 s(100.00%)'}, 'err_msg': 'Pipeline failed: Unsupported function type batch_norm', 'err_info': 'Traceback (most recent call last):\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py\", line 162, in run_pipe\\n    self.parse()\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py\", line 226, in parse\\n    info, report = self._parse()\\n                   ^^^^^^^^^^^^^\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/dynamic.py\", line 157, in _parse\\n    info[name], report[name] = w_ctx[\"worker\"].parse()\\n                               ^^^^^^^^^^^^^^^^^^^^^^^\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/worker.py\", line 320, in parse\\n    self._relax_mod, _ = stage_config[\"parser\"](self._model, **parse_config)\\n                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/frontend/translate.py\", line 119, in from_torch\\n    relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)\\n                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py\", line 960, in from_fx\\n    return TorchFXImporter().from_fx(\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py\", line 837, in from_fx\\n    assert (\\nAssertionError: Unsupported function type batch_norm\\n'}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[12:54:57] /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:65: Warning: BlockBuilder destroyed with remaining blocks!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'success': False, 'info': {'prepare': {'profile': {'jit_0': '42.50 ms @ cpu'}}}, 'duration': {'setup': '0.00 s(0.00%)', 'prepare': '4.81 s(49.18%)', 'parse': '0.08 s(0.82%)', 'total': '9.78 s(100.00%)'}, 'err_msg': 'Pipeline failed: Unsupported function type batch_norm', 'err_info': 'Traceback (most recent call last):\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py\", line 162, in run_pipe\\n    self.parse()\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py\", line 226, in parse\\n    info, report = self._parse()\\n                   ^^^^^^^^^^^^^\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/dynamic.py\", line 157, in _parse\\n    info[name], report[name] = w_ctx[\"worker\"].parse()\\n                               ^^^^^^^^^^^^^^^^^^^^^^^\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/worker.py\", line 320, in parse\\n    self._relax_mod, _ = stage_config[\"parser\"](self._model, **parse_config)\\n                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/frontend/translate.py\", line 119, in from_torch\\n    relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)\\n                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py\", line 960, in from_fx\\n    return TorchFXImporter().from_fx(\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^\\n  File \"/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py\", line 837, in from_fx\\n    assert (\\nAssertionError: Unsupported function type batch_norm\\n'}\n"
     ]
    }
   ],
   "source": [
    "from tvm.contrib.msc.core.utils.namespace import MSCFramework\n",
    "from tvm.contrib.msc.pipeline import TorchDynamic\n",
    "import torch\n",
    "\n",
    "for compile_type in [MSCFramework.TORCH, MSCFramework.TVM]:\n",
    "    # A fresh eval-mode model is built for every compile backend.\n",
    "    torch_model = _get_torch_model(\"resnet50\", False)\n",
    "    if torch_model is None:\n",
    "        # _get_torch_model returns None when torchvision is unavailable; stop\n",
    "        # here with a clear message instead of failing later on a None model.\n",
    "        raise RuntimeError(\"torchvision is required to build the resnet50 model\")\n",
    "    if torch.cuda.is_available():\n",
    "        torch_model = torch_model.to(torch.device(\"cuda:0\"))\n",
    "    config = _get_config(\n",
    "        MSCFramework.TORCH,\n",
    "        compile_type,\n",
    "        inputs=[[\"input_0\", [1, 3, 224, 224], \"float32\"]],\n",
    "        outputs=[\"output\"],\n",
    "        dynamic=True,\n",
    "        atol=1e-1,\n",
    "        rtol=1e-1,\n",
    "    )\n",
    "    pipeline = TorchDynamic(torch_model, config)\n",
    "    try:\n",
    "        pipeline.run_pipe()  # run the pipeline\n",
    "        print(pipeline.report)  # print the pipeline report\n",
    "    finally:\n",
    "        # Always tear the pipeline down, even if the run or the print raises.\n",
    "        # NOTE(review): the `destory` spelling is kept as-is; the recorded run\n",
    "        # completed with it, so it appears to be the API's actual method name.\n",
    "        pipeline.destory()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
